from abc import ABC, abstractmethod
import numpy as np
from .data_helper import str_to_date
import lazy_property

from datetime import timedelta


def state_to_vect(state, d, timetrend):
    """Encode a ``(payment, priority, action, cov)`` state as a feature vector.

    Features, in order: payment > 0, payment, one-hot priority (G1, G2, G3),
    one-hot action (valor, rec1, medida), cov, and the interactions
    rec1*G1 and medida*G1. ``timetrend * [d]`` is appended at the end, i.e.
    ``d`` is included once when ``timetrend`` is True and omitted when False.
    Booleans are turned into numbers by the final multiplication by 1.
    """
    payment, priority, action, cov = state
    in_g1 = priority == 'G1'
    on_rec1 = action == 'rec1'
    on_medida = action == 'medida'
    features = [
        payment > 0,
        payment,
        in_g1,
        priority == 'G2',
        priority == 'G3',
        action == 'valor',
        on_rec1,
        on_medida,
        cov,
        on_rec1 * in_g1,
        on_medida * in_g1,
    ] + timetrend * [d]
    return 1 * np.array(features)

class NextStateBase(ABC):
    """Abstract state-transition model with a Bernoulli payment event.

    Subclasses implement ``__call__``, ``s_shape`` and
    ``distribution_payment``. ``proba_payment`` reads ``self.timetrend``,
    which subclasses are expected to set.
    """

    @abstractmethod
    def __call__(self, *args, **kwargs):
        """Advance the state by one step."""

    def makes_payment(self, param, state, theta, d):
        """Return 1 with probability ``proba_payment(...)`` and 0 otherwise."""
        prob = self.proba_payment(param, state, theta, d)
        outcomes = [0, 1]
        weights = [1 - prob, prob]
        return np.random.choice(outcomes, p=weights)

    def proba_payment(self, param, state, theta, d):
        """Return ``1 - exp(-s_shape(param . x + theta))``.

        ``x`` is ``state_to_vect(state, d, self.timetrend)``.
        """
        features = state_to_vect(state, d, self.timetrend)
        linear_index = np.dot(param, features) + theta
        return 1 - np.exp(-self.s_shape(linear_index))

    @abstractmethod
    def s_shape(self, x):
        """Map the linear index ``x`` to a non-negative hazard."""

    @abstractmethod
    def distribution_payment(self, *args, **kwargs):
        """Sample the size of a payment."""


class ConstantPolicyNextStep(NextStateBase):
    """Single-state transition in which priority, action and cov never change.

    Only the payment component evolves. With probability
    ``proba_payment(self.params, ...)`` a uniform draw from
    ``[0, max_payment)`` is added to it.
    """

    def __init__(self, params, a, b, max_payment, timetrend=False):
        self.params, self.a, self.b = params, a, b
        self.max_payment = max_payment
        self.timetrend = timetrend

    def __call__(self, state, theta, d):
        """Return ``(new_state, theta)`` after at most one payment."""
        paid, priority, action, cov = state
        if self.makes_payment(self.params, state, theta, d):
            paid += self.distribution_payment()
        new_state = (paid, priority, action, cov)
        return new_state, theta

    def s_shape(self, x):
        """Return ``x - a`` clipped below at 1e-3 and above at ``b - a``."""
        lower, upper = self.a, self.b
        return min(upper - lower, max(x - lower, 1e-3))

    def distribution_payment(self):
        """Return ``max_payment`` times a uniform draw from [0, 1)."""
        draw = np.random.rand(1)[0]
        return self.max_payment * draw


class VectorNextStepBase(NextStateBase):
    """Vectorised next-state model acting on a 2-D array of rows.

    ``_split`` divides each row into three parts:

    * ``total_due``: the first column.
    * ``theta``: the last column.
    * the state: every column in between. ``state_to_vect`` consumes it
      through ``makes_payment``.

    ``params`` is unpacked as ``[*coefficients, a, b, sigma_sdev]``:

    * the last entry is stored as ``sigma_sdev``;
    * the two before it are the ``a``/``b`` bounds used by ``s_shape``;
    * the rest are the coefficients passed to ``proba_payment``.

    ``update_priorities`` and ``update_actions`` are hooks for subclasses.
    The defaults return the array unchanged. Subclasses may also define
    ``no_priority_update_weeks``, the times at which the priority hook is
    skipped.
    """

    def __init__(self, params, map_payments, list_tax_due=None, list_rel_pay=None,
                 timetrend=False, enriched_payments=False, max_payment_threshold=4,
                 options=None):
        """
        Args:
            params: see the class docstring for the layout.
            map_payments: mapping from a tax bracket to the candidate
                relative payment sizes sampled by ``distribution_payment``.
                When ``enriched_payments`` is set, the keys are
                ``(tax_bracket, pay_bracket)`` tuples instead.
            list_tax_due: bracket bounds for the amount due.
            list_rel_pay: bracket bounds for the relative payment.
            timetrend: append the time index to the feature vector.
            enriched_payments: key ``map_payments`` by due *and* payment bracket.
            max_payment_threshold: no payment is added once the relative
                payment exceeds this value.
            options: optional dict, completed by ``set_default_options``.

        Raises:
            ValueError: if ``map_payments`` is empty, or if its key type does
                not match ``enriched_payments``.
        """
        self._suggested_options = options
        self.sigma_sdev = params[-1]
        self.a, self.b = params[-3:-1]
        self.params = params[:-3]
        self.map_payments = map_payments
        self.enriched_payments = enriched_payments
        self.max_payment_threshold = max_payment_threshold
        self._check_map_payments_keys(map_payments, enriched_payments)
        self.list_tax_due = list_tax_due if list_tax_due is not None else \
            [99, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1250, 1500, 2000,
             60000]
        self.list_rel_pay = list_rel_pay if list_rel_pay is not None else [-.1, 0.5, 540]
        self.timetrend = timetrend

    @staticmethod
    def _check_map_payments_keys(map_payments, enriched_payments):
        """Check that the key type of ``map_payments`` matches ``enriched_payments``."""
        try:
            first_key = next(iter(map_payments.keys()))
        except StopIteration:
            raise ValueError('map_payments must not be empty') from None
        if isinstance(first_key, tuple):
            if not enriched_payments:
                raise ValueError('Enriched payments should not be a tuple')
            if len(first_key) != 2:
                raise ValueError('Enriched payments requires a 2D tuple')
        elif enriched_payments:
            raise ValueError('Enriched payments key should be a tuple')

    @lazy_property.LazyProperty
    def options(self):
        """Options with defaults filled in; computed once on first access."""
        return self.set_default_options(self._suggested_options)

    @staticmethod
    def set_default_options(options):
        """Return a copy of ``options`` with 's_shape' and 'param_transform' defaulted.

        The caller's dict is not modified.
        """
        options = {} if options is None else dict(options)
        options.setdefault('s_shape', 'linear')
        options.setdefault('param_transform', lambda x: x)
        return options

    def __call__(self, array_total_due_state_type, time):
        """Advance every row by one step: payments, then priorities, then actions.

        The priority update is skipped when ``time`` is in
        ``no_priority_update_weeks``. An undefined attribute counts as empty.
        """
        a = self.update_payments(array_total_due_state_type, time)
        if time not in getattr(self, 'no_priority_update_weeks', ()):
            a = self.update_priorities(a, time)
        return self.update_actions(a)

    def update_priorities(self, array_total_due_state_type=None, *args, **kwargs):
        """Hook for subclasses. By default, return the array unchanged."""
        return array_total_due_state_type

    def update_actions(self, array_total_due_state_type=None, *args, **kwargs):
        """Hook for subclasses. By default, return the array unchanged."""
        return array_total_due_state_type

    def update_payments(self, a, d):
        """Apply ``_update_payments`` to every row of ``a`` at time ``d``."""
        return np.apply_along_axis(self._update_payments, 1, a, d)

    def _update_payments(self, total_due_state_type, d):
        """Possibly add a payment to one row (column 1) and return the row.

        Steps, in order:

        1. Normalise the state's payment by ``total_due``.
        2. Sample a candidate payment. This happens on every call, so the
           random stream always advances.
        3. Add the candidate only when ``makes_payment`` succeeds and the
           normalised payment does not exceed ``max_payment_threshold``.

        The row passed in is modified in place.
        """
        total_due, state, theta = self._split(total_due_state_type)
        incr_payment = 0
        normalized_state = state.copy()
        normalized_state[0] = np.divide(state[0], total_due)
        potential_payment = self.distribution_payment(total_due, payment=normalized_state[0])
        if self.makes_payment(self.params, normalized_state, theta, d):
            incr_payment = potential_payment
        if normalized_state[0] > self.max_payment_threshold:
            incr_payment = 0
        total_due_state_type[1] += incr_payment
        return total_due_state_type

    def _split(self, total_due_state_type):
        """Return ``(first column, middle columns, last column)`` of a row."""
        total_due = total_due_state_type[0]
        theta = total_due_state_type[-1]
        state = total_due_state_type[1:-1]
        return total_due, state, theta

    def s_shape(self, x):
        """Map the linear index ``x`` to a hazard.

        With ``options['s_shape'] == 'logistic'`` the result is
        ``b / (1 + exp(a - x))``. Otherwise it is ``x - a``, clipped below at
        1e-3 and above at ``b - a``.
        """
        a, b = self.a, self.b
        if self.options['s_shape'] == 'logistic':
            return np.divide(b, 1 + np.exp(-x + a))
        return min(b-a, max(x-a, 1e-3))

    def distribution_payment(self, total_due, payment=0):
        """Return ``total_due`` times a relative size drawn from ``map_payments``.

        Raises:
            ValueError: if the bracket key is not present in ``map_payments``.
        """
        if self.enriched_payments:
            key = (self.tax_bracket(total_due), self.pay_bracket(payment))
            message = 'Due or payment bracket not found'
        else:
            key = self.tax_bracket(total_due)
            message = 'Due bracket not found'
        try:
            candidates = self.map_payments[key]
        except (LookupError, TypeError) as err:
            raise ValueError(message) from err
        return total_due * np.random.choice(candidates)

    @staticmethod
    def _lower_bound(bounds, value):
        """Return ``bounds[i]`` for the first ``i`` with ``bounds[i + 1] > value``.

        Return None if there is no such ``i``.
        """
        for lower, upper in zip(bounds[:-1], bounds[1:]):
            if upper > value:
                return lower
        return None

    def tax_bracket(self, tax_due):
        """Return the ``list_tax_due`` bracket of ``tax_due``, or None above the last bound."""
        return self._lower_bound(self.list_tax_due, tax_due)

    def pay_bracket(self, payment):
        """Return the ``list_rel_pay`` bracket of ``payment``, or None above the last bound."""
        return self._lower_bound(self.list_rel_pay, payment)


class CBS:
    """Column indices into the rows of the simulation state array.

    NOTE(review): index 4 has no name here. ``VectorNextStepBase._split``
    treats every column between the first and the last as the state, and
    ``state_to_vect`` unpacks exactly four state fields (payment, priority,
    action, cov). That implies a 6-column row, with cov at 4 and theta at 5.
    Confirm ``cov`` and ``theta`` below against the code that builds the
    array before relying on them.
    """

    total_due: int = 0   # amount due
    payment: int = 1     # accumulated payment (incremented by _update_payments)
    priority: int = 2    # priority label, e.g. 'G1', 'G2', 'G3'
    action: int = 3      # action label: 'N', 'valor', 'rec1' or 'medida'
    cov: int = 5
    theta: int = 6


# Escalation rank of each action: 'N' < 'valor' < 'rec1' < 'medida'.
_order_actions = {name: rank for rank, name in enumerate(('N', 'valor', 'rec1', 'medida'))}


def order_actions(a):
    """Return the escalation rank of action ``a``; raise KeyError if unknown."""
    rank = _order_actions[a]
    return rank


class TargetPrioritiesTargetActions(VectorNextStepBase):
    """Weekly policy simulator that re-targets priorities and escalates actions.

    Each call advances ``current_date`` by one week and then runs the
    ``VectorNextStepBase`` pipeline:

    * Payments are drawn row by row.
    * Priorities are re-targeted (``update_priorities``) unless the time index
      is in ``no_priority_update_weeks``.
    * Actions of G1 rows escalate with the weeks spent at their priority
      (``_assign_actions``).
    * The ``floor_valor``/``floor_rec1`` floors are enforced
      (``_apply_action_floor``).

    Rows are indexed with the ``CBS`` column constants. A row is "in payment"
    when ``payment / total_due >= in_payment_threshold``.
    """

    def __init__(self, params, map_payments, n,
                 start_date='2021-04-12', config=None,
                 list_tax_due=None, list_rel_pay=None, timetrend=False,
                 enriched_payments=False, max_payment_threshold=4, options=None):
        """
        Args:
            n: number of rows tracked (the length of ``priority_date``).
            start_date: simulation start, parsed with ``str_to_date``.
            config: policy parameters; the default dict below lists the
                required keys. It may also carry ``no_priority_update_weeks``.

        The remaining arguments are forwarded to ``VectorNextStepBase``.
        """
        super().__init__(params, map_payments, list_tax_due, list_rel_pay,
                         timetrend, enriched_payments,
                         max_payment_threshold, options)
        config = config if config is not None else dict(
            target_G1=400, in_payment_threshold=.5, target_G2=400,
            floor_valor=0, floor_rec1=0, weeks_valor=1, weeks_rec1=4,
            weeks_medida=6, weeks_floor_valor=6, weeks_floor_rec1=12,
            max_flow_G1=1e10, max_flow_G1_change=1e10, target_G1_change=600,
            target_G1_change_week=7)
        self.target_G1 = config['target_G1']
        self.target_G1_change = config['target_G1_change']
        self.target_G1_change_week = config['target_G1_change_week']
        self.target_G2 = config['target_G2']
        self.floor_valor = config['floor_valor']
        self.floor_rec1 = config['floor_rec1']
        self.weeks_valor = config['weeks_valor']
        self.weeks_rec1 = config['weeks_rec1']
        self.weeks_medida = config['weeks_medida']
        self.weeks_floor_valor = config['weeks_floor_valor']
        self.weeks_floor_rec1 = config['weeks_floor_rec1']
        self.in_payment_threshold = config['in_payment_threshold']
        self.max_flow_G1 = config['max_flow_G1']
        self.max_flow_G1_change = config['max_flow_G1_change']
        # Time indices at which update_priorities is skipped (see __call__).
        self.no_priority_update_weeks = config.get('no_priority_update_weeks', set())
        self.start_date = str_to_date(start_date)
        self.current_date = self.start_date
        # Date at which each row entered its current priority.
        self.priority_date = np.array(n * [self.start_date])
        self.n = n
        self.one_week = timedelta(days=7)
        # ``timetrend`` is already stored by VectorNextStepBase.__init__.

    def __call__(self, array_total_due_state_type, d):
        """Advance the calendar one week, then run the base step with time index ``d``."""
        self._update_current_date()
        return super().__call__(array_total_due_state_type, d)

    def reset(self):
        """Rewind the calendar to ``start_date`` and restart every priority clock.

        ``priority_date`` is reset too. Otherwise dates recorded in a previous
        run would lie after ``current_date`` and give negative weeks at
        priority.
        """
        self.current_date = self.start_date
        self.priority_date = np.array(self.n * [self.start_date])

    def _update_current_date(self):
        self.current_date += self.one_week

    def update_priorities(self, array_total_due_state_type, time):
        """Re-target the priorities of rows that are not in payment.

        G1 step:

        * G1 rows that are neither in payment nor under 'medida' keep G1.
        * Non-G1 rows not in payment are promoted to G1 in row order while
          the running count stays within the capacity.
        * The capacity is ``min(target - kept, flow cap)``. Before
          ``target_G1_change_week`` this uses ``target_G1``/``max_flow_G1``;
          from then on it uses ``target_G1_change``/``max_flow_G1_change``.

        G2 step: rows not in payment, not under 'medida', not kept in G1 and
        not promoted to G1 are set to G2 in row order, up to ``target_G2``.

        Rows whose priority changed get ``priority_date = current_date``.
        """
        initial_array = np.copy(array_total_due_state_type)
        in_payment = self._is_in_payment(array_total_due_state_type)
        is_medida = array_total_due_state_type[:, CBS.action] == 'medida'
        is_G1 = array_total_due_state_type[:, CBS.priority] == 'G1'
        remaining_g1 = ~in_payment & ~is_medida & is_G1
        candidate_g1 = ~in_payment & ~is_G1
        if time < self.target_G1_change_week:
            g1_capacity = min(self.target_G1 - sum(remaining_g1), self.max_flow_G1)
        else:
            g1_capacity = min(self.target_G1_change - sum(remaining_g1),
                              self.max_flow_G1_change)
        # Running count of candidates: only the first ``g1_capacity`` of them
        # (in row order) fit.
        feasible_g1 = np.cumsum(candidate_g1) <= g1_capacity
        array_total_due_state_type[
            candidate_g1 & feasible_g1, CBS.priority] = 'G1'
        candidate_g2 = (~in_payment) & (~feasible_g1) & (~remaining_g1) & (~is_medida)
        feasible_g2 = np.cumsum(candidate_g2) <= self.target_G2
        array_total_due_state_type[
            candidate_g2 & feasible_g2, CBS.priority] = 'G2'

        self._update_time_at_priority(
            array_total_due_state_type, initial_array)
        return array_total_due_state_type

    def _is_in_payment(self, array_total_due_state_type):
        """Return a boolean mask of rows with ``payment / total_due >= in_payment_threshold``."""
        money = array_total_due_state_type[:, [0, 1]].astype(float)
        relative_payment = np.divide(money[:, 1], money[:, 0])
        return relative_payment >= self.in_payment_threshold

    def _update_time_at_priority(self, a1, a0):
        """Stamp ``current_date`` on rows whose priority differs between ``a0`` and ``a1``."""
        new_priority = a1[:, CBS.priority] != a0[:, CBS.priority]
        self.priority_date[new_priority] = self.current_date

    @property
    def weeks_spent_at_priority(self):
        """Return the weeks elapsed since each row entered its current priority."""
        return (self.current_date - self.priority_date) / self.one_week

    def update_actions(self, array_total_due_state_type):
        """Escalate the actions of rows not in payment, then apply the action floors."""
        weeks_priority_act = self._weeks_priority_act(
            array_total_due_state_type)
        in_payment = self._is_in_payment(array_total_due_state_type)
        actions = np.array([
            self._assign_actions(r) for r in weeks_priority_act[~in_payment]])
        array_total_due_state_type[~in_payment, CBS.action] = actions
        array_total_due_state_type = self._apply_action_floor(
            array_total_due_state_type)
        return array_total_due_state_type

    def _weeks_priority_act(self, array_total_due_state_type):
        """Return an (n, 3) array whose rows are ``(weeks_at_priority, priority, action)``."""
        return np.vstack(
            (self.weeks_spent_at_priority,
             array_total_due_state_type[:, CBS.priority],
             array_total_due_state_type[:, CBS.action])).T

    def _assign_actions(self, r):
        """Return the action for one ``(weeks_at_priority, priority, action)`` row.

        Non-G1 rows keep their action. G1 rows escalate with time at priority:

        * below ``weeks_valor``: unchanged;
        * ``[weeks_valor, weeks_rec1)``: at least 'valor';
        * ``[weeks_rec1, weeks_medida)``: at least 'rec1';
        * from ``weeks_medida`` on: 'medida'.

        None is returned if no interval matches, e.g. for a NaN time.
        """
        time_at_priority, priority, action = r
        if priority != 'G1':
            return action
        if time_at_priority < self.weeks_valor:
            return action
        if self.weeks_valor <= time_at_priority < self.weeks_rec1:
            return max(['valor', action], key=order_actions)
        if self.weeks_rec1 <= time_at_priority < self.weeks_medida:
            return max(['rec1', action], key=order_actions)
        if self.weeks_medida <= time_at_priority:
            return 'medida'

    def _apply_action_floor(self, array_total_due_state_type):
        """Top up 'valor' and 'rec1' actions towards their ramping floors.

        Valor floor: the count of rows with any action other than 'N',
        divided by ``floor_valor``, is raised towards ``target_share_valor``.
        This is done by switching 'N' rows not in payment to 'valor', in row
        order.

        Rec1 floor: the count of 'rec1'/'medida' rows, divided by
        ``floor_rec1``, is raised towards ``target_share_rec1``. This is done
        by switching rows that were 'valor' on entry (and are not in payment)
        to 'rec1'.

        A non-positive floor disables the corresponding top-up. This avoids
        dividing by zero with the default config (floors of 0).
        """
        is_valor = array_total_due_state_type[:, CBS.action] == 'valor'
        is_rec1 = array_total_due_state_type[:, CBS.action] == 'rec1'
        is_N = array_total_due_state_type[:, CBS.action] == 'N'
        is_medida = array_total_due_state_type[:, CBS.action] == 'medida'
        in_payment = self._is_in_payment(array_total_due_state_type)
        if self.floor_valor > 0:
            share_valor = np.count_nonzero(~is_N) / self.floor_valor
            if share_valor < self.target_share_valor:
                num_to_fill = np.floor(
                    (self.target_share_valor - share_valor) * self.floor_valor)
                candidates = is_N & (~in_payment)
                feasible = np.cumsum(candidates) <= num_to_fill
                array_total_due_state_type[feasible & candidates, CBS.action] = \
                    'valor'
        if self.floor_rec1 > 0:
            share_rec1 = np.count_nonzero(is_rec1 | is_medida) / self.floor_rec1
            if share_rec1 < self.target_share_rec1:
                num_to_fill = np.floor(
                    (self.target_share_rec1 - share_rec1) * self.floor_rec1)
                candidates = is_valor & (~in_payment)
                feasible = np.cumsum(candidates) <= num_to_fill
                array_total_due_state_type[feasible & candidates, CBS.action] = \
                    'rec1'
        return array_total_due_state_type

    @property
    def target_share_valor(self):
        """Return elapsed weeks / ``weeks_floor_valor``, capped at 1."""
        s = (self.current_date - self.start_date) / (
                self.weeks_floor_valor * self.one_week)
        return min(s, 1)

    @property
    def target_share_rec1(self):
        """Return the rec1 ramp, clipped to [0, 1].

        The ramp is
        ``(elapsed weeks - weeks_floor_valor) / (weeks_floor_rec1 - weeks_floor_valor)``.
        """
        s = (self.current_date - self.start_date
             - self.weeks_floor_valor * self.one_week) / (
                self.one_week * (self.weeks_floor_rec1
                                 - self.weeks_floor_valor))
        return min(1, max(s, 0))

class CounterfactualConfigs:
    """Catalogue of policy config dicts for counterfactual simulation runs.

    Every scenario is exposed as an instance attribute. Variants are built
    from a parent scenario with ``_derive``. It returns a new dict with the
    parent's keys, in the parent's order, plus the given overrides. Nested
    values such as the ``no_priority_update_weeks`` sets are shared with the
    parent, not copied.

    Args:
        max_flow_G1: value of ``max_flow_G1`` in ``base_config``.
        max_flow_G1_change: value of ``max_flow_G1_change`` in ``base_config``.
        initial_G1: stored as ``self.initial_G1``.
    """

    def __init__(self, max_flow_G1=41, max_flow_G1_change=41, initial_G1=240):
        self.max_flow_G1 = max_flow_G1
        self.max_flow_G1_change = max_flow_G1_change
        self.initial_G1 = initial_G1

        derive = self._derive

        self.base_config = dict(
            target_G1=400, in_payment_threshold=.5, target_G2=246,
            floor_valor=1, floor_rec1=1, weeks_valor=2, weeks_rec1=4,
            weeks_medida=6, weeks_floor_valor=4, weeks_floor_rec1=8,
            max_flow_G1=max_flow_G1, max_flow_G1_change=max_flow_G1_change,
            target_G1_change=400, target_G1_change_week=7)

        self.control_config = dict(
            target_G1=1e10, in_payment_threshold=.5, target_G2=1e10,
            floor_valor=4550, floor_rec1=3450, weeks_valor=6, weeks_rec1=10,
            weeks_medida=11, target_G1_change=1e10, target_G1_change_week=7,
            weeks_floor_valor=10, weeks_floor_rec1=16,
            max_flow_G1=119, max_flow_G1_change=0)

        # Override groups that recur across scenarios.
        more_rec1 = dict(floor_valor=1600, floor_rec1=1300,
                         weeks_floor_valor=1, weeks_floor_rec1=2)
        matching_rec1 = {
            key: self.control_config[key]
            for key in ('floor_valor', 'floor_rec1',
                        'weeks_floor_valor', 'weeks_floor_rec1')}
        more_G1 = dict(max_flow_G1=1e10, max_flow_G1_change=1e10, target_G2=400)
        longer_G1 = dict(weeks_rec1=4, weeks_medida=8)

        self.more_rec1_config = derive(self.base_config, **more_rec1)
        self.more_rec1_control_config = derive(self.control_config, **more_rec1)
        self.matching_rec1_config = derive(self.base_config, **matching_rec1)
        self.more_G1_config = derive(self.base_config, **more_G1)
        self.more_G1_matching_rec1 = derive(self.matching_rec1_config, **more_G1)
        self.shrinkingG1_config = derive(
            self.base_config, max_flow_G1=36, max_flow_G1_change=36, target_G2=216)
        self.shorter_G1_config = derive(
            self.base_config, weeks_valor=0, weeks_rec1=2, weeks_medida=4)
        self.longer_G1_config = derive(self.base_config, **longer_G1)
        self.longer_G1_more_G1 = derive(
            self.longer_G1_config, **more_G1,
            target_G1_change=600, target_G1_change_week=7)
        self.longer_G1_more_G1_conservative = derive(
            self.longer_G1_config, **more_G1,
            target_G1=400, target_G1_change=400, target_G1_change_week=7)
        self.longer_G1_more_G1_actual = derive(
            self.longer_G1_more_G1,
            no_priority_update_weeks={7, 8, 11, 15}, weeks_medida=9)
        self.longer_G1_more_rec1_config = derive(self.more_rec1_config, **longer_G1)
        self.longer_G1_matching_rec1_config = derive(self.matching_rec1_config, **longer_G1)
        self.longer_G1_more_G1_matching_rec1 = derive(self.more_G1_matching_rec1, **longer_G1)
        self.longer_G1_more_G1_more_rec1 = derive(self.longer_G1_more_rec1_config, **more_G1)

        self.implementation_2022 = dict(
            target_G1=500, in_payment_threshold=.5, target_G2=500,
            floor_valor=1, floor_rec1=1, weeks_valor=2, weeks_rec1=4,
            weeks_medida=7, weeks_floor_valor=4, weeks_floor_rec1=8,
            max_flow_G1=0, max_flow_G1_change=500,
            target_G1_change=500, target_G1_change_week=8,
            no_priority_update_weeks=set(range(9, 23)))
        self.implementation_2022_more_rec1 = derive(self.implementation_2022, **more_rec1)
        self.implementation_2022_matching_rec1 = derive(
            self.implementation_2022, **matching_rec1)
        self.single_round_implementation_2022 = derive(
            self.implementation_2022_matching_rec1, weeks_medida=10, target_G1_change=0)

    @staticmethod
    def _derive(parent, **overrides):
        """Return a shallow copy of ``parent`` with ``overrides`` applied."""
        return {**parent, **overrides}
